{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7652b16c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zhengzitao/anaconda3/envs/02toxic/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from utils.stratify import stratified_train_test_split\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d3deacc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"\")\n",
    "from model import clf_model\n",
    "from utils.data_builder import data_builder\n",
    "import numpy as np\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "from memory_profiler import profile\n",
    "import time\n",
    "\n",
    "\n",
    "EPOCHES = 1\n",
    "X_PATH = \"data/sentence_codes_4096_dm0.npy\" # 预训练模型\n",
    "ORIG_PATH = \"data/train.csv\"\n",
    "LABLE_NAME = [\"\", \"\", \"toxic\",\"severe_toxic\",\"obscene\",\"threat\",\"insult\",\"identity_hate\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0b65895c",
   "metadata": {},
   "outputs": [],
   "source": [
    "builder = data_builder(X_PATH, ORIG_PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "863e4fa2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The execution time of the function 'prepare_data' is 9.389770s\n"
     ]
    }
   ],
   "source": [
    "data, label = builder.prepare_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ac0f5f8a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting score: 283. Calculated in 0:00:00\n",
      "Epoch 1/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 2/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 3/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 4/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 5/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 6/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 7/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 8/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 9/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 10/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 11/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 12/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 13/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 14/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 15/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 16/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 17/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 18/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 19/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 20/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 21/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 22/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 23/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 24/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 25/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 26/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 27/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 28/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 29/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 30/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 31/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 32/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 33/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 34/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 35/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 36/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 37/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 38/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 39/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 40/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 41/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 42/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 43/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 44/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 45/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 46/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 47/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 48/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 49/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 50/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 51/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 52/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 53/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 54/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 55/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 56/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 57/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 58/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 59/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 60/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 61/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 62/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 63/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 64/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 65/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 66/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 67/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 68/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 69/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 70/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 71/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 72/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 73/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 74/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 75/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 76/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 77/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 78/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 79/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 80/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 81/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 82/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 83/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 84/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 85/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 86/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 87/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 88/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 89/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 90/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 91/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 92/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 93/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 94/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 95/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 96/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 97/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 98/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 99/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 100/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 101/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 102/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 103/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 104/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 105/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 106/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 107/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 108/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 109/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 110/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 111/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 112/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 113/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 114/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 115/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 116/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 117/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 118/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 119/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 120/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 121/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 122/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 123/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 124/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 125/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 126/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 127/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 128/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 129/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 130/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 131/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 132/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 133/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 134/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 135/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 136/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 137/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 138/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 139/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 140/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 141/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 142/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 143/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 144/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 145/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 146/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 147/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 148/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 149/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 150/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 151/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 152/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 153/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 154/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 155/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 156/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 157/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 158/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 159/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 160/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 161/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 162/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 163/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 164/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 165/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 166/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 167/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 168/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 169/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 170/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 171/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 172/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 173/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 174/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 175/500 score: 283. Calculated in 0:00:00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 176/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 177/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 178/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 179/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 180/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 181/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 182/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 183/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 184/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 185/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 186/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 187/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 188/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 189/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 190/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 191/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 192/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 193/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 194/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 195/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 196/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 197/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 198/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 199/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 200/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 201/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 202/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 203/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 204/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 205/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 206/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 207/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 208/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 209/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 210/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 211/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 212/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 213/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 214/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 215/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 216/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 217/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 218/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 219/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 220/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 221/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 222/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 223/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 224/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 225/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 226/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 227/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 228/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 229/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 230/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 231/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 232/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 233/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 234/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 235/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 236/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 237/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 238/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 239/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 240/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 241/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 242/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 243/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 244/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 245/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 246/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 247/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 248/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 249/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 250/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 251/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 252/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 253/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 254/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 255/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 256/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 257/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 258/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 259/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 260/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 261/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 262/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 263/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 264/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 265/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 266/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 267/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 268/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 269/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 270/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 271/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 272/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 273/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 274/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 275/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 276/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 277/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 278/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 279/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 280/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 281/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 282/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 283/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 284/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 285/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 286/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 287/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 288/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 289/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 290/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 291/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 292/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 293/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 294/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 295/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 296/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 297/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 298/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 299/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 300/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 301/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 302/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 303/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 304/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 305/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 306/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 307/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 308/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 309/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 310/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 311/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 312/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 313/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 314/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 315/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 316/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 317/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 318/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 319/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 320/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 321/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 322/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 323/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 324/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 325/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 326/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 327/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 328/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 329/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 330/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 331/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 332/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 333/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 334/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 335/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 336/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 337/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 338/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 339/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 340/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 341/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 342/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 343/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 344/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 345/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 346/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 347/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 348/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 349/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 350/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 351/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 352/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 353/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 354/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 355/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 356/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 357/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 358/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 359/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 360/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 361/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 362/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 363/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 364/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 365/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 366/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 367/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 368/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 369/500 score: 283. Calculated in 0:00:00\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 370/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 371/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 372/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 373/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 374/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 375/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 376/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 377/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 378/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 379/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 380/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 381/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 382/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 383/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 384/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 385/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 386/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 387/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 388/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 389/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 390/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 391/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 392/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 393/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 394/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 395/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 396/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 397/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 398/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 399/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 400/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 401/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 402/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 403/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 404/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 405/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 406/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 407/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 408/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 409/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 410/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 411/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 412/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 413/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 414/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 415/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 416/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 417/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 418/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 419/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 420/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 421/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 422/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 423/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 424/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 425/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 426/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 427/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 428/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 429/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 430/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 431/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 432/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 433/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 434/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 435/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 436/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 437/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 438/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 439/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 440/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 441/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 442/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 443/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 444/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 445/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 446/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 447/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 448/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 449/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 450/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 451/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 452/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 453/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 454/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 455/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 456/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 457/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 458/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 459/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 460/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 461/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 462/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 463/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 464/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 465/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 466/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 467/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 468/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 469/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 470/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 471/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 472/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 473/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 474/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 475/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 476/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 477/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 478/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 479/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 480/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 481/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 482/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 483/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 484/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 485/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 486/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 487/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 488/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 489/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 490/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 491/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 492/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 493/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 494/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 495/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 496/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 497/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 498/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 499/500 score: 283. Calculated in 0:00:00\n",
      "Epoch 500/500 score: 283. Calculated in 0:00:00\n",
      "To train: 0\n",
      "To test: 0\n",
      "Target test size: 0.2\n",
      "Actual test size: 0.186\n"
     ]
    }
   ],
   "source": [
    "train_X, test_X, train_y, test_y = stratified_train_test_split(data, label, 0.2, epochs=500, random_state=1234)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c2427ae2",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = np.concatenate((train_X, train_y), 1)\n",
    "train_loader = DataLoader(train_data, batch_size=64, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "b5f9918a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "aka = None\n",
    "for data in train_loader:\n",
    "    aka = data\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "d6da0ba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "a8aa1953",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_X = data[:, 4096].type(torch.float32)\n",
    "batch_y = data[:, 4096:].reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "c44496e7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0837,  0.0370, -0.0497,  ..., -0.0250,  0.2157, -0.1312],\n",
       "        [-0.0926,  0.1282, -0.1224,  ..., -0.0674,  0.0173,  0.0697],\n",
       "        [-0.0678, -0.1788, -0.0398,  ..., -0.0875,  0.0498, -0.0067],\n",
       "        ...,\n",
       "        [-0.0439, -0.1531, -0.0295,  ..., -0.2182, -0.0409,  0.1715],\n",
       "        [-0.1809,  0.1411,  0.0190,  ...,  0.0162, -0.0068, -0.1859],\n",
       "        [-0.0676,  0.0490, -0.1412,  ..., -0.0590,  0.0228, -0.3074]],\n",
       "       dtype=torch.float64)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "kk = np.array(()) "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ff70f39",
   "metadata": {},
   "source": [
    "# Main"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "36e9cf81",
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCHES = 1\n",
    "X_PATH = \"data/sentence_codes_4096_dm0.npy\" # 预训练模型\n",
    "ORIG_PATH = \"data/train.csv\"\n",
    "LABLE_NAME = [\"toxic\",\"severe_toxic\",\"obscene\",\"threat\",\"insult\",\"identity_hate\"]\n",
    "\n",
    "def build_model_list():\n",
    "    model_list = []\n",
    "    for name in LABLE_NAME:\n",
    "        f = open('out/{}.pickle'.format(name),'rb') \n",
    "        s = f.read()\n",
    "        model_list.append(pickle.loads(s))\n",
    "\n",
    "    return model_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "523be8a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The execution time of the function 'prepare_data' is 17.616019s\n"
     ]
    }
   ],
   "source": [
    "builder = data_builder(X_PATH, ORIG_PATH)\n",
    "train_X, train_y = builder.build_all()\n",
    "small_model_list = build_model_list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "27888da8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import multi_model\n",
    "model = multi_model(small_model_list, EPOCHES)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "48fc4dfa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "kkkkk\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(100, 6)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred = model.predict_proba(train_X)\n",
    "y_pred.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f6674b01",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.00032226, -0.00040472, -0.00022585, -0.00021952, -0.00022862,\n",
       "       -0.00022313], dtype=float32)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_pred[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9574fca1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0, 0])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_y[0]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:02toxic]",
   "language": "python",
   "name": "conda-env-02toxic-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "vscode": {
   "interpreter": {
    "hash": "44228eae3222851a050825964598a19668a841865d547ce1c3c6b7890971bdf0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
